Simple DDPM¶
In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import matplotlib.pyplot as plt
Cosine Noise Scheduler¶
In [77]:
# total timesteps
T = 1000
In [78]:
# # cosine scheduling (Improved DDPM)
# s = 0.008
# f = torch.cos(torch.pi * (torch.linspace(0, 1, T+1) + s) / (2 + 2*s)) ** 2
# alpha_bar = f / f[0]
# alpha = alpha_bar[1:] / alpha_bar[:-1]
# alpha_bar = alpha_bar[1:]
# beta = 1 - alpha
In [79]:
# linear scheduling
beta_1 = 1e-4 # alpha_1 = 0.9999
beta_T = 0.02 # alpha_T = 0.98
beta = torch.linspace(beta_1, beta_T, T)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)
In [80]:
print(beta.shape)
print(alpha.shape)
print(alpha_bar.shape)
torch.Size([1000]) torch.Size([1000]) torch.Size([1000])
Dataset¶
$$ \sin(2\pi at) + \sin(2\pi bt) \;\; where \;\; a \sim \text{Uniform}(100,110), b \sim \text{Uniform}(1,11) $$
In [81]:
n = 4000 # number of data
a = torch.randint(1000, 1101, (n,), dtype=torch.float32) / 10
b = torch.randint(10, 111, (n,), dtype=torch.float32) / 10
fs = 4000
t = torch.linspace(0, 1, 1*fs)
data = torch.sin(2 * torch.pi * a.reshape(-1,1) @ t.reshape(1,-1)) + torch.sin(2 * torch.pi * b.reshape(-1,1) @ t.reshape(1,-1))
data.shape
Out[81]:
torch.Size([4000, 4000])
In [82]:
# def pattern(d,h,r,k):
# return 4*(torch.sin(k*d*h/r))**2
# n = 1000 # number of data
# d = torch.randint(200, 301, (n,), dtype=torch.float32) / 10
# h = torch.randint(100, 301, (n,), dtype=torch.float32) / 10
# r = torch.randint(160, 201, (n,), dtype=torch.float32)
# k=torch.linspace(0, 4.2, 5120)
# t = torch.linspace(0, 2560, 5120)
# data = pattern(d.reshape(-1,1),h.reshape(-1,1),r.reshape(-1,1),k)
# data.shape
In [83]:
# data examples
for i in range(10):
plt.figure(figsize=(12,5))
plt.plot(t, data[i*10])
plt.grid()
plt.show()
In [84]:
data = data.unsqueeze(dim=1)
data.shape
Out[84]:
torch.Size([4000, 1, 4000])
In [85]:
class MySignalDataset(Dataset):
def __init__(self, data):
"""
data: (num_signals, channel=1, signal_length) 형태의 텐서
"""
self.data = data
def __len__(self):
return self.data.size(0) # 신호의 총 개수 (1000)
def __getitem__(self, idx):
"""
idx번째 신호(1D 텐서)를 반환.
"""
return self.data[idx]
# 1) Dataset 생성
dataset = MySignalDataset(data)
# 2) DataLoader로 배치 단위 생성
# 원하는 batch_size 로 설정하세요 (예: 16)
batch_size = 10
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 3) DataLoader 순회 예시
for i, batch_data in enumerate(dataloader):
# batch_data.shape == (batch_size, signal_length)
print(f"Batch {i} shape: {batch_data.shape}")
# ... 학습 또는 처리 로직 ...
# break # 한 번만 보고 싶다면
Batch 0 shape: torch.Size([10, 1, 4000]) Batch 1 shape: torch.Size([10, 1, 4000]) Batch 2 shape: torch.Size([10, 1, 4000]) Batch 3 shape: torch.Size([10, 1, 4000]) Batch 4 shape: torch.Size([10, 1, 4000]) Batch 5 shape: torch.Size([10, 1, 4000]) Batch 6 shape: torch.Size([10, 1, 4000]) Batch 7 shape: torch.Size([10, 1, 4000]) Batch 8 shape: torch.Size([10, 1, 4000]) Batch 9 shape: torch.Size([10, 1, 4000]) Batch 10 shape: torch.Size([10, 1, 4000]) Batch 11 shape: torch.Size([10, 1, 4000]) Batch 12 shape: torch.Size([10, 1, 4000]) Batch 13 shape: torch.Size([10, 1, 4000]) Batch 14 shape: torch.Size([10, 1, 4000]) Batch 15 shape: torch.Size([10, 1, 4000]) Batch 16 shape: torch.Size([10, 1, 4000]) Batch 17 shape: torch.Size([10, 1, 4000]) Batch 18 shape: torch.Size([10, 1, 4000]) Batch 19 shape: torch.Size([10, 1, 4000]) Batch 20 shape: torch.Size([10, 1, 4000]) Batch 21 shape: torch.Size([10, 1, 4000]) Batch 22 shape: torch.Size([10, 1, 4000]) Batch 23 shape: torch.Size([10, 1, 4000]) Batch 24 shape: torch.Size([10, 1, 4000]) Batch 25 shape: torch.Size([10, 1, 4000]) Batch 26 shape: torch.Size([10, 1, 4000]) Batch 27 shape: torch.Size([10, 1, 4000]) Batch 28 shape: torch.Size([10, 1, 4000]) Batch 29 shape: torch.Size([10, 1, 4000]) Batch 30 shape: torch.Size([10, 1, 4000]) Batch 31 shape: torch.Size([10, 1, 4000]) Batch 32 shape: torch.Size([10, 1, 4000]) Batch 33 shape: torch.Size([10, 1, 4000]) Batch 34 shape: torch.Size([10, 1, 4000]) Batch 35 shape: torch.Size([10, 1, 4000]) Batch 36 shape: torch.Size([10, 1, 4000]) Batch 37 shape: torch.Size([10, 1, 4000]) Batch 38 shape: torch.Size([10, 1, 4000]) Batch 39 shape: torch.Size([10, 1, 4000]) Batch 40 shape: torch.Size([10, 1, 4000]) Batch 41 shape: torch.Size([10, 1, 4000]) Batch 42 shape: torch.Size([10, 1, 4000]) Batch 43 shape: torch.Size([10, 1, 4000]) Batch 44 shape: torch.Size([10, 1, 4000]) Batch 45 shape: torch.Size([10, 1, 4000]) Batch 46 shape: torch.Size([10, 1, 4000]) Batch 47 shape: torch.Size([10, 1, 4000]) Batch 48 shape: torch.Size([10, 1, 4000]) Batch 49 shape: torch.Size([10, 1, 4000]) Batch 50 shape: torch.Size([10, 1, 4000]) Batch 51 shape: torch.Size([10, 1, 4000]) Batch 52 shape: torch.Size([10, 1, 4000]) Batch 53 shape: torch.Size([10, 1, 4000]) Batch 54 shape: torch.Size([10, 1, 4000]) Batch 55 shape: torch.Size([10, 1, 4000]) Batch 56 shape: torch.Size([10, 1, 4000]) Batch 57 shape: torch.Size([10, 1, 4000]) Batch 58 shape: torch.Size([10, 1, 4000]) Batch 59 shape: torch.Size([10, 1, 4000]) Batch 60 shape: torch.Size([10, 1, 4000]) Batch 61 shape: torch.Size([10, 1, 4000]) Batch 62 shape: torch.Size([10, 1, 4000]) Batch 63 shape: torch.Size([10, 1, 4000]) Batch 64 shape: torch.Size([10, 1, 4000]) Batch 65 shape: torch.Size([10, 1, 4000]) Batch 66 shape: torch.Size([10, 1, 4000]) Batch 67 shape: torch.Size([10, 1, 4000]) Batch 68 shape: torch.Size([10, 1, 4000]) Batch 69 shape: torch.Size([10, 1, 4000]) Batch 70 shape: torch.Size([10, 1, 4000]) Batch 71 shape: torch.Size([10, 1, 4000]) Batch 72 shape: torch.Size([10, 1, 4000]) Batch 73 shape: torch.Size([10, 1, 4000]) Batch 74 shape: torch.Size([10, 1, 4000]) Batch 75 shape: torch.Size([10, 1, 4000]) Batch 76 shape: torch.Size([10, 1, 4000]) Batch 77 shape: torch.Size([10, 1, 4000]) Batch 78 shape: torch.Size([10, 1, 4000]) Batch 79 shape: torch.Size([10, 1, 4000]) Batch 80 shape: torch.Size([10, 1, 4000]) Batch 81 shape: torch.Size([10, 1, 4000]) Batch 82 shape: torch.Size([10, 1, 4000]) Batch 83 shape: torch.Size([10, 1, 4000]) Batch 84 shape: torch.Size([10, 1, 4000]) Batch 85 shape: torch.Size([10, 1, 4000]) Batch 86 shape: torch.Size([10, 1, 4000]) Batch 87 shape: torch.Size([10, 1, 4000]) Batch 88 shape: torch.Size([10, 1, 4000]) Batch 89 shape: torch.Size([10, 1, 4000]) Batch 90 shape: torch.Size([10, 1, 4000]) Batch 91 shape: torch.Size([10, 1, 4000]) Batch 92 shape: torch.Size([10, 1, 4000]) Batch 93 shape: torch.Size([10, 1, 4000]) Batch 94 shape: torch.Size([10, 1, 4000]) Batch 95 shape: torch.Size([10, 1, 4000]) Batch 96 shape: torch.Size([10, 1, 4000]) Batch 97 shape: torch.Size([10, 1, 4000]) Batch 98 shape: torch.Size([10, 1, 4000]) Batch 99 shape: torch.Size([10, 1, 4000]) Batch 100 shape: torch.Size([10, 1, 4000]) Batch 101 shape: torch.Size([10, 1, 4000]) Batch 102 shape: torch.Size([10, 1, 4000]) Batch 103 shape: torch.Size([10, 1, 4000]) Batch 104 shape: torch.Size([10, 1, 4000]) Batch 105 shape: torch.Size([10, 1, 4000]) Batch 106 shape: torch.Size([10, 1, 4000]) Batch 107 shape: torch.Size([10, 1, 4000]) Batch 108 shape: torch.Size([10, 1, 4000]) Batch 109 shape: torch.Size([10, 1, 4000]) Batch 110 shape: torch.Size([10, 1, 4000]) Batch 111 shape: torch.Size([10, 1, 4000]) Batch 112 shape: torch.Size([10, 1, 4000]) Batch 113 shape: torch.Size([10, 1, 4000]) Batch 114 shape: torch.Size([10, 1, 4000]) Batch 115 shape: torch.Size([10, 1, 4000]) Batch 116 shape: torch.Size([10, 1, 4000]) Batch 117 shape: torch.Size([10, 1, 4000]) Batch 118 shape: torch.Size([10, 1, 4000]) Batch 119 shape: torch.Size([10, 1, 4000]) Batch 120 shape: torch.Size([10, 1, 4000]) Batch 121 shape: torch.Size([10, 1, 4000]) Batch 122 shape: torch.Size([10, 1, 4000]) Batch 123 shape: torch.Size([10, 1, 4000]) Batch 124 shape: torch.Size([10, 1, 4000]) Batch 125 shape: torch.Size([10, 1, 4000]) Batch 126 shape: torch.Size([10, 1, 4000]) Batch 127 shape: torch.Size([10, 1, 4000]) Batch 128 shape: torch.Size([10, 1, 4000]) Batch 129 shape: torch.Size([10, 1, 4000]) Batch 130 shape: torch.Size([10, 1, 4000]) Batch 131 shape: torch.Size([10, 1, 4000]) Batch 132 shape: torch.Size([10, 1, 4000]) Batch 133 shape: torch.Size([10, 1, 4000]) Batch 134 shape: torch.Size([10, 1, 4000]) Batch 135 shape: torch.Size([10, 1, 4000]) Batch 136 shape: torch.Size([10, 1, 4000]) Batch 137 shape: torch.Size([10, 1, 4000]) Batch 138 shape: torch.Size([10, 1, 4000]) Batch 139 shape: torch.Size([10, 1, 4000]) Batch 140 shape: torch.Size([10, 1, 4000]) Batch 141 shape: torch.Size([10, 1, 4000]) Batch 142 shape: torch.Size([10, 1, 4000]) Batch 143 shape: torch.Size([10, 1, 4000]) Batch 144 shape: torch.Size([10, 1, 4000]) Batch 145 shape: torch.Size([10, 1, 4000]) Batch 146 shape: torch.Size([10, 1, 4000]) Batch 147 shape: torch.Size([10, 1, 4000]) Batch 148 shape: torch.Size([10, 1, 4000]) Batch 149 shape: torch.Size([10, 1, 4000]) Batch 150 shape: torch.Size([10, 1, 4000]) Batch 151 shape: torch.Size([10, 1, 4000]) Batch 152 shape: torch.Size([10, 1, 4000]) Batch 153 shape: torch.Size([10, 1, 4000]) Batch 154 shape: torch.Size([10, 1, 4000]) Batch 155 shape: torch.Size([10, 1, 4000]) Batch 156 shape: torch.Size([10, 1, 4000]) Batch 157 shape: torch.Size([10, 1, 4000]) Batch 158 shape: torch.Size([10, 1, 4000]) Batch 159 shape: torch.Size([10, 1, 4000]) Batch 160 shape: torch.Size([10, 1, 4000]) Batch 161 shape: torch.Size([10, 1, 4000]) Batch 162 shape: torch.Size([10, 1, 4000]) Batch 163 shape: torch.Size([10, 1, 4000]) Batch 164 shape: torch.Size([10, 1, 4000]) Batch 165 shape: torch.Size([10, 1, 4000]) Batch 166 shape: torch.Size([10, 1, 4000]) Batch 167 shape: torch.Size([10, 1, 4000]) Batch 168 shape: torch.Size([10, 1, 4000]) Batch 169 shape: torch.Size([10, 1, 4000]) Batch 170 shape: torch.Size([10, 1, 4000]) Batch 171 shape: torch.Size([10, 1, 4000]) Batch 172 shape: torch.Size([10, 1, 4000]) Batch 173 shape: torch.Size([10, 1, 4000]) Batch 174 shape: torch.Size([10, 1, 4000]) Batch 175 shape: torch.Size([10, 1, 4000]) Batch 176 shape: torch.Size([10, 1, 4000]) Batch 177 shape: torch.Size([10, 1, 4000]) Batch 178 shape: torch.Size([10, 1, 4000]) Batch 179 shape: torch.Size([10, 1, 4000]) Batch 180 shape: torch.Size([10, 1, 4000]) Batch 181 shape: torch.Size([10, 1, 4000]) Batch 182 shape: torch.Size([10, 1, 4000]) Batch 183 shape: torch.Size([10, 1, 4000]) Batch 184 shape: torch.Size([10, 1, 4000]) Batch 185 shape: torch.Size([10, 1, 4000]) Batch 186 shape: torch.Size([10, 1, 4000]) Batch 187 shape: torch.Size([10, 1, 4000]) Batch 188 shape: torch.Size([10, 1, 4000]) Batch 189 shape: torch.Size([10, 1, 4000]) Batch 190 shape: torch.Size([10, 1, 4000]) Batch 191 shape: torch.Size([10, 1, 4000]) Batch 192 shape: torch.Size([10, 1, 4000]) Batch 193 shape: torch.Size([10, 1, 4000]) Batch 194 shape: torch.Size([10, 1, 4000]) Batch 195 shape: torch.Size([10, 1, 4000]) Batch 196 shape: torch.Size([10, 1, 4000]) Batch 197 shape: torch.Size([10, 1, 4000]) Batch 198 shape: torch.Size([10, 1, 4000]) Batch 199 shape: torch.Size([10, 1, 4000]) Batch 200 shape: torch.Size([10, 1, 4000]) Batch 201 shape: torch.Size([10, 1, 4000]) Batch 202 shape: torch.Size([10, 1, 4000]) Batch 203 shape: torch.Size([10, 1, 4000]) Batch 204 shape: torch.Size([10, 1, 4000]) Batch 205 shape: torch.Size([10, 1, 4000]) Batch 206 shape: torch.Size([10, 1, 4000]) Batch 207 shape: torch.Size([10, 1, 4000]) Batch 208 shape: torch.Size([10, 1, 4000]) Batch 209 shape: torch.Size([10, 1, 4000]) Batch 210 shape: torch.Size([10, 1, 4000]) Batch 211 shape: torch.Size([10, 1, 4000]) Batch 212 shape: torch.Size([10, 1, 4000]) Batch 213 shape: torch.Size([10, 1, 4000]) Batch 214 shape: torch.Size([10, 1, 4000]) Batch 215 shape: torch.Size([10, 1, 4000]) Batch 216 shape: torch.Size([10, 1, 4000]) Batch 217 shape: torch.Size([10, 1, 4000]) Batch 218 shape: torch.Size([10, 1, 4000]) Batch 219 shape: torch.Size([10, 1, 4000]) Batch 220 shape: torch.Size([10, 1, 4000]) Batch 221 shape: torch.Size([10, 1, 4000]) Batch 222 shape: torch.Size([10, 1, 4000]) Batch 223 shape: torch.Size([10, 1, 4000]) Batch 224 shape: torch.Size([10, 1, 4000]) Batch 225 shape: torch.Size([10, 1, 4000]) Batch 226 shape: torch.Size([10, 1, 4000]) Batch 227 shape: torch.Size([10, 1, 4000]) Batch 228 shape: torch.Size([10, 1, 4000]) Batch 229 shape: torch.Size([10, 1, 4000]) Batch 230 shape: torch.Size([10, 1, 4000]) Batch 231 shape: torch.Size([10, 1, 4000]) Batch 232 shape: torch.Size([10, 1, 4000]) Batch 233 shape: torch.Size([10, 1, 4000]) Batch 234 shape: torch.Size([10, 1, 4000]) Batch 235 shape: torch.Size([10, 1, 4000]) Batch 236 shape: torch.Size([10, 1, 4000]) Batch 237 shape: torch.Size([10, 1, 4000]) Batch 238 shape: torch.Size([10, 1, 4000]) Batch 239 shape: torch.Size([10, 1, 4000]) Batch 240 shape: torch.Size([10, 1, 4000]) Batch 241 shape: torch.Size([10, 1, 4000]) Batch 242 shape: torch.Size([10, 1, 4000]) Batch 243 shape: torch.Size([10, 1, 4000]) Batch 244 shape: torch.Size([10, 1, 4000]) Batch 245 shape: torch.Size([10, 1, 4000]) Batch 246 shape: torch.Size([10, 1, 4000]) Batch 247 shape: torch.Size([10, 1, 4000]) Batch 248 shape: torch.Size([10, 1, 4000]) Batch 249 shape: torch.Size([10, 1, 4000]) Batch 250 shape: torch.Size([10, 1, 4000]) Batch 251 shape: torch.Size([10, 1, 4000]) Batch 252 shape: torch.Size([10, 1, 4000]) Batch 253 shape: torch.Size([10, 1, 4000]) Batch 254 shape: torch.Size([10, 1, 4000]) Batch 255 shape: torch.Size([10, 1, 4000]) Batch 256 shape: torch.Size([10, 1, 4000]) Batch 257 shape: torch.Size([10, 1, 4000]) Batch 258 shape: torch.Size([10, 1, 4000]) Batch 259 shape: torch.Size([10, 1, 4000]) Batch 260 shape: torch.Size([10, 1, 4000]) Batch 261 shape: torch.Size([10, 1, 4000]) Batch 262 shape: torch.Size([10, 1, 4000]) Batch 263 shape: torch.Size([10, 1, 4000]) Batch 264 shape: torch.Size([10, 1, 4000]) Batch 265 shape: torch.Size([10, 1, 4000]) Batch 266 shape: torch.Size([10, 1, 4000]) Batch 267 shape: torch.Size([10, 1, 4000]) Batch 268 shape: torch.Size([10, 1, 4000]) Batch 269 shape: torch.Size([10, 1, 4000]) Batch 270 shape: torch.Size([10, 1, 4000]) Batch 271 shape: torch.Size([10, 1, 4000]) Batch 272 shape: torch.Size([10, 1, 4000]) Batch 273 shape: torch.Size([10, 1, 4000]) Batch 274 shape: torch.Size([10, 1, 4000]) Batch 275 shape: torch.Size([10, 1, 4000]) Batch 276 shape: torch.Size([10, 1, 4000]) Batch 277 shape: torch.Size([10, 1, 4000]) Batch 278 shape: torch.Size([10, 1, 4000]) Batch 279 shape: torch.Size([10, 1, 4000]) Batch 280 shape: torch.Size([10, 1, 4000]) Batch 281 shape: torch.Size([10, 1, 4000]) Batch 282 shape: torch.Size([10, 1, 4000]) Batch 283 shape: torch.Size([10, 1, 4000]) Batch 284 shape: torch.Size([10, 1, 4000]) Batch 285 shape: torch.Size([10, 1, 4000]) Batch 286 shape: torch.Size([10, 1, 4000]) Batch 287 shape: torch.Size([10, 1, 4000]) Batch 288 shape: torch.Size([10, 1, 4000]) Batch 289 shape: torch.Size([10, 1, 4000]) Batch 290 shape: torch.Size([10, 1, 4000]) Batch 291 shape: torch.Size([10, 1, 4000]) Batch 292 shape: torch.Size([10, 1, 4000]) Batch 293 shape: torch.Size([10, 1, 4000]) Batch 294 shape: torch.Size([10, 1, 4000]) Batch 295 shape: torch.Size([10, 1, 4000]) Batch 296 shape: torch.Size([10, 1, 4000]) Batch 297 shape: torch.Size([10, 1, 4000]) Batch 298 shape: torch.Size([10, 1, 4000]) Batch 299 shape: torch.Size([10, 1, 4000]) Batch 300 shape: torch.Size([10, 1, 4000]) Batch 301 shape: torch.Size([10, 1, 4000]) Batch 302 shape: torch.Size([10, 1, 4000]) Batch 303 shape: torch.Size([10, 1, 4000]) Batch 304 shape: torch.Size([10, 1, 4000]) Batch 305 shape: torch.Size([10, 1, 4000]) Batch 306 shape: torch.Size([10, 1, 4000]) Batch 307 shape: torch.Size([10, 1, 4000]) Batch 308 shape: torch.Size([10, 1, 4000]) Batch 309 shape: torch.Size([10, 1, 4000]) Batch 310 shape: torch.Size([10, 1, 4000]) Batch 311 shape: torch.Size([10, 1, 4000]) Batch 312 shape: torch.Size([10, 1, 4000]) Batch 313 shape: torch.Size([10, 1, 4000]) Batch 314 shape: torch.Size([10, 1, 4000]) Batch 315 shape: torch.Size([10, 1, 4000]) Batch 316 shape: torch.Size([10, 1, 4000]) Batch 317 shape: torch.Size([10, 1, 4000]) Batch 318 shape: torch.Size([10, 1, 4000]) Batch 319 shape: torch.Size([10, 1, 4000]) Batch 320 shape: torch.Size([10, 1, 4000]) Batch 321 shape: torch.Size([10, 1, 4000]) Batch 322 shape: torch.Size([10, 1, 4000]) Batch 323 shape: torch.Size([10, 1, 4000]) Batch 324 shape: torch.Size([10, 1, 4000]) Batch 325 shape: torch.Size([10, 1, 4000]) Batch 326 shape: torch.Size([10, 1, 4000]) Batch 327 shape: torch.Size([10, 1, 4000]) Batch 328 shape: torch.Size([10, 1, 4000]) Batch 329 shape: torch.Size([10, 1, 4000]) Batch 330 shape: torch.Size([10, 1, 4000]) Batch 331 shape: torch.Size([10, 1, 4000]) Batch 332 shape: torch.Size([10, 1, 4000]) Batch 333 shape: torch.Size([10, 1, 4000]) Batch 334 shape: torch.Size([10, 1, 4000]) Batch 335 shape: torch.Size([10, 1, 4000]) Batch 336 shape: torch.Size([10, 1, 4000]) Batch 337 shape: torch.Size([10, 1, 4000]) Batch 338 shape: torch.Size([10, 1, 4000]) Batch 339 shape: torch.Size([10, 1, 4000]) Batch 340 shape: torch.Size([10, 1, 4000]) Batch 341 shape: torch.Size([10, 1, 4000]) Batch 342 shape: torch.Size([10, 1, 4000]) Batch 343 shape: torch.Size([10, 1, 4000]) Batch 344 shape: torch.Size([10, 1, 4000]) Batch 345 shape: torch.Size([10, 1, 4000]) Batch 346 shape: torch.Size([10, 1, 4000]) Batch 347 shape: torch.Size([10, 1, 4000]) Batch 348 shape: torch.Size([10, 1, 4000]) Batch 349 shape: torch.Size([10, 1, 4000]) Batch 350 shape: torch.Size([10, 1, 4000]) Batch 351 shape: torch.Size([10, 1, 4000]) Batch 352 shape: torch.Size([10, 1, 4000]) Batch 353 shape: torch.Size([10, 1, 4000]) Batch 354 shape: torch.Size([10, 1, 4000]) Batch 355 shape: torch.Size([10, 1, 4000]) Batch 356 shape: torch.Size([10, 1, 4000]) Batch 357 shape: torch.Size([10, 1, 4000]) Batch 358 shape: torch.Size([10, 1, 4000]) Batch 359 shape: torch.Size([10, 1, 4000]) Batch 360 shape: torch.Size([10, 1, 4000]) Batch 361 shape: torch.Size([10, 1, 4000]) Batch 362 shape: torch.Size([10, 1, 4000]) Batch 363 shape: torch.Size([10, 1, 4000]) Batch 364 shape: torch.Size([10, 1, 4000]) Batch 365 shape: torch.Size([10, 1, 4000]) Batch 366 shape: torch.Size([10, 1, 4000]) Batch 367 shape: torch.Size([10, 1, 4000]) Batch 368 shape: torch.Size([10, 1, 4000]) Batch 369 shape: torch.Size([10, 1, 4000]) Batch 370 shape: torch.Size([10, 1, 4000]) Batch 371 shape: torch.Size([10, 1, 4000]) Batch 372 shape: torch.Size([10, 1, 4000]) Batch 373 shape: torch.Size([10, 1, 4000]) Batch 374 shape: torch.Size([10, 1, 4000]) Batch 375 shape: torch.Size([10, 1, 4000]) Batch 376 shape: torch.Size([10, 1, 4000]) Batch 377 shape: torch.Size([10, 1, 4000]) Batch 378 shape: torch.Size([10, 1, 4000]) Batch 379 shape: torch.Size([10, 1, 4000]) Batch 380 shape: torch.Size([10, 1, 4000]) Batch 381 shape: torch.Size([10, 1, 4000]) Batch 382 shape: torch.Size([10, 1, 4000]) Batch 383 shape: torch.Size([10, 1, 4000]) Batch 384 shape: torch.Size([10, 1, 4000]) Batch 385 shape: torch.Size([10, 1, 4000]) Batch 386 shape: torch.Size([10, 1, 4000]) Batch 387 shape: torch.Size([10, 1, 4000]) Batch 388 shape: torch.Size([10, 1, 4000]) Batch 389 shape: torch.Size([10, 1, 4000]) Batch 390 shape: torch.Size([10, 1, 4000]) Batch 391 shape: torch.Size([10, 1, 4000]) Batch 392 shape: torch.Size([10, 1, 4000]) Batch 393 shape: torch.Size([10, 1, 4000]) Batch 394 shape: torch.Size([10, 1, 4000]) Batch 395 shape: torch.Size([10, 1, 4000]) Batch 396 shape: torch.Size([10, 1, 4000]) Batch 397 shape: torch.Size([10, 1, 4000]) Batch 398 shape: torch.Size([10, 1, 4000]) Batch 399 shape: torch.Size([10, 1, 4000])
In [86]:
def exists(x):
return x is not None
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim, theta = 10000):
super().__init__()
self.dim = dim
self.theta = theta
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(self.theta) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
def Upsample(dim, dim_out):
return nn.Sequential(
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv1d(dim, dim_out, 3, padding=1)
)
def Downsample(dim, dim_out):
return nn.Conv1d(dim, dim_out, 4, 2, 1)
class RMSNorm(nn.Module):
def __init__(self, dim):
super().__init__()
self.scale = dim ** 0.5
self.g = nn.Parameter(torch.ones(1, dim, 1))
def forward(self, x):
return F.normalize(x, dim=1) * self.g * self.scale
class Block(nn.Module):
def __init__(self, dim, dim_out, dropout=0.):
super().__init__()
self.proj = nn.Conv1d(dim, dim_out, 3, padding=1)
self.norm = RMSNorm(dim_out)
self.act = nn.SiLU()
self.dropout = nn.Dropout(dropout)
def forward(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift # scaling & shifting using time emb
x = self.act(x)
return self.dropout(x)
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim=None, dropout=0.):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None
self.block1 = Block(dim, dim_out, dropout=dropout)
self.block2 = Block(dim_out, dim_out)
self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
scale_shift = None
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = time_emb.unsqueeze(2)
scale_shift = time_emb.chunk(2, dim=1)
h = self.block1(x, scale_shift=scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Sequential(
nn.Conv1d(hidden_dim, dim, 1),
RMSNorm(dim)
)
def forward(self, x):
B, C, L = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: t.reshape(B, self.heads, -1, L), qkv)
q = q.softmax(dim=-2) # [B, h, d, L]
k = k.softmax(dim=-1) # [B, h, d, L]
# v.shape = [B, h, e, L]
q = q * self.scale
context = torch.matmul(k, v.transpose(-1,-2)) # [B, h, d, e]
out = torch.matmul(context.transpose(-1,-2), q) # [B, h, d, L]
out = out.reshape(B, -1, L)
return self.to_out(out)
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
self.dim_head = dim_head
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv1d(hidden_dim, dim, 1)
def forward(self, x):
B, C, L = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(lambda t: t.reshape(B, self.heads, -1, L), qkv)
# q.shape = [B, h, d, I], k.shape = [B, h, d, L], v.shape = [B, h, d, L]
q = q * self.scale
sim = torch.matmul(q.transpose(-1,-2), k) # [B, h, I, L]
attn = sim.softmax(dim=-1)
out = torch.matmul(attn, v.transpose(-1,-2)) # [B, h, I, d]
out = out.reshape(B, -1, L)
return self.to_out(out)
class LinearAttentionBlock(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.norm = RMSNorm(dim)
self.fc = LinearAttention(dim, heads, dim_head)
def forward(self, x):
h = self.norm(x)
h = self.fc(h)
return h + x
class AttentionBlock(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.norm = RMSNorm(dim)
self.fc = Attention(dim, heads, dim_head)
def forward(self, x):
h = self.norm(x)
h = self.fc(h)
return h + x
class DownBlock(nn.Module):
def __init__(self, dim, dim_in, dim_out, heads=4, dim_head=32, dropout=0., last=False):
super().__init__()
self.time_dim = dim * 4
self.resnetblock1 = ResnetBlock(dim_in, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
self.resnetblock2 = ResnetBlock(dim_in, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
self.linattnblock = LinearAttentionBlock(dim_in, heads, dim_head)
self.downsample = Downsample(dim_in, dim_out) if not last else nn.Conv1d(dim_in, dim_out, 3, padding=1)
def forward(self, x, t):
h1 = self.resnetblock1(x, t)
h2 = self.resnetblock2(h1, t)
h2 = self.linattnblock(h2)
out = self.downsample(h2)
return out, h1, h2
class MidBlock(nn.Module):
def __init__(self, dim, mid_dim, heads=4, dim_head=32, dropout=0.):
super().__init__()
self.time_dim = dim * 4
self.resnetblock1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=self.time_dim, dropout=dropout)
self.attnblock = AttentionBlock(mid_dim, heads, dim_head)
self.resnetblock2 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=self.time_dim, dropout=dropout)
def forward(self, x, t):
x = self.resnetblock1(x, t)
x = self.attnblock(x)
return self.resnetblock2(x, t)
class UpBlock(nn.Module):
def __init__(self, dim, dim_in, dim_out, heads=4, dim_head=32, dropout=0., last=False):
super().__init__()
self.time_dim = dim * 4
self.resnetblock1 = ResnetBlock(dim_in + dim_out, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
self.resnetblock2 = ResnetBlock(dim_in + dim_out, dim_in, time_emb_dim=self.time_dim, dropout=dropout)
self.linattnblock = LinearAttentionBlock(dim_in, heads, dim_head)
self.upsample = Upsample(dim_in, dim_out) if not last else nn.Conv1d(dim_in, dim_out, 3, padding=1)
def forward(self, x, h1, h2, t):
x = self.resnetblock1(torch.cat((x, h1), dim=1), t)
x = self.resnetblock2(torch.cat((x, h2), dim=1), t)
x = self.linattnblock(x)
return self.upsample(x)
class Unet(nn.Module):
def __init__(self, dim=16):
super(Unet, self).__init__()
time_dim = dim * 4
self.init_conv = nn.Conv1d(1, dim, 7, padding=3)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
self.down1 = DownBlock(dim, dim_in = dim, dim_out = dim*2)
self.down2 = DownBlock(dim, dim_in = dim*2, dim_out = dim*4)
self.down3 = DownBlock(dim, dim_in = dim*4, dim_out = dim*8)
self.down4 = DownBlock(dim, dim_in = dim*8, dim_out = dim*16, last=True)
self.mid = MidBlock(dim, mid_dim = dim*16)
self.up1 = UpBlock(dim, dim_in = dim*16, dim_out = dim*8)
self.up2 = UpBlock(dim, dim_in = dim*8, dim_out = dim*4)
self.up3 = UpBlock(dim, dim_in = dim*4, dim_out = dim*2)
self.up4 = UpBlock(dim, dim_in = dim*2, dim_out = dim, last=True)
self.final_res_block = ResnetBlock(dim*2, dim)
self.final_conv = nn.Conv1d(dim, 1, 1)
def forward(self, x, time):
r = self.init_conv(x)
t = self.time_mlp(time)
x, h1, h2 = self.down1(r, t)
x, h3, h4 = self.down2(x, t)
x, h5, h6 = self.down3(x, t)
x, h7, h8 = self.down4(x, t)
x = self.mid(x, t)
x = self.up1(x, h7, h8, t)
x = self.up2(x, h5, h6, t)
x = self.up3(x, h3, h4, t)
x = self.up4(x, h1, h2, t)
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
In [87]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
In [88]:
def train_ddpm(model, dataloader, optimizer, epochs=50):
model.train()
model.to(device)
for epoch in range(epochs):
total_loss = 0
for step, x0 in enumerate(dataloader):
# x0.shape = [batch_size, 1, signal_length]
x0 = x0.to(device)
# 1) t를 uniform random으로 뽑기
# t range: [1, T], 실제 구현에서는 [0, T-1]도 가능
batch_size = x0.shape[0]
t = torch.randint(
low=1, high=T+1, size=(batch_size,), device=device
) # t in [1..T]
# 2) 노이즈 epsilon 샘플
epsilon = torch.randn_like(x0)
# 3) alpha_bar_t 가져오기
# t-1 인덱스로 indexing (파이썬은 0-based, t는 1-based)
alpha_bar_t = alpha_bar.to(device)[t-1].reshape(batch_size, 1, 1)
# 4) x_t 생성
# x_t = sqrt(alpha_bar[t]) * x0 + sqrt(1 - alpha_bar[t]) * epsilon
sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1. - alpha_bar_t)
x_t = sqrt_alpha_bar_t * x0 + sqrt_one_minus_alpha_bar_t * epsilon
# 5) 모델 추론: 모델은 x_t와 t를 입력받아 epsilon 예측
epsilon_pred = model(x_t, t) # shape: same as x0
# 6) MSE Loss
loss = F.mse_loss(epsilon_pred, epsilon)
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f"Epoch [{epoch+1}/{epochs}] - Loss: {total_loss / len(dataloader):.4f}")
In [89]:
model = Unet()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
# 훈련
train_ddpm(model, dataloader, optimizer, epochs=200)
Epoch [1/200] - Loss: 0.1021 Epoch [2/200] - Loss: 0.0219 Epoch [3/200] - Loss: 0.0155 Epoch [4/200] - Loss: 0.0118 Epoch [5/200] - Loss: 0.0106 Epoch [6/200] - Loss: 0.0095 Epoch [7/200] - Loss: 0.0093 Epoch [8/200] - Loss: 0.0090 Epoch [9/200] - Loss: 0.0081 Epoch [10/200] - Loss: 0.0094 Epoch [11/200] - Loss: 0.0080 Epoch [12/200] - Loss: 0.0077 Epoch [13/200] - Loss: 0.0075 Epoch [14/200] - Loss: 0.0071 Epoch [15/200] - Loss: 0.0075 Epoch [16/200] - Loss: 0.0071 Epoch [17/200] - Loss: 0.0093 Epoch [18/200] - Loss: 0.0067 Epoch [19/200] - Loss: 0.0065 Epoch [20/200] - Loss: 0.0067 Epoch [21/200] - Loss: 0.0065 Epoch [22/200] - Loss: 0.0065 Epoch [23/200] - Loss: 0.0065 Epoch [24/200] - Loss: 0.0072 Epoch [25/200] - Loss: 0.0062 Epoch [26/200] - Loss: 0.0062 Epoch [27/200] - Loss: 0.0063 Epoch [28/200] - Loss: 0.0061 Epoch [29/200] - Loss: 0.0061 Epoch [30/200] - Loss: 0.0059 Epoch [31/200] - Loss: 0.0058 Epoch [32/200] - Loss: 0.0063 Epoch [33/200] - Loss: 0.0064 Epoch [34/200] - Loss: 0.0057 Epoch [35/200] - Loss: 0.0057 Epoch [36/200] - Loss: 0.0055 Epoch [37/200] - Loss: 0.0056 Epoch [38/200] - Loss: 0.0058 Epoch [39/200] - Loss: 0.0062 Epoch [40/200] - Loss: 0.0054 Epoch [41/200] - Loss: 0.0054 Epoch [42/200] - Loss: 0.0057 Epoch [43/200] - Loss: 0.0054 Epoch [44/200] - Loss: 0.0052 Epoch [45/200] - Loss: 0.0054 Epoch [46/200] - Loss: 0.0052 Epoch [47/200] - Loss: 0.0051 Epoch [48/200] - Loss: 0.0050 Epoch [49/200] - Loss: 0.0050 Epoch [50/200] - Loss: 0.0051 Epoch [51/200] - Loss: 0.0051 Epoch [52/200] - Loss: 0.0050 Epoch [53/200] - Loss: 0.0052 Epoch [54/200] - Loss: 0.0049 Epoch [55/200] - Loss: 0.0047 Epoch [56/200] - Loss: 0.0048 Epoch [57/200] - Loss: 0.0048 Epoch [58/200] - Loss: 0.0047 Epoch [59/200] - Loss: 0.0049 Epoch [60/200] - Loss: 0.0045 Epoch [61/200] - Loss: 0.0047 Epoch [62/200] - Loss: 0.0045 Epoch [63/200] - Loss: 0.0047 Epoch [64/200] - Loss: 0.0045 Epoch [65/200] - Loss: 0.0042 Epoch [66/200] - Loss: 0.0044 Epoch [67/200] - Loss: 0.0044 Epoch [68/200] - Loss: 0.0042 Epoch [69/200] - Loss: 0.0044 Epoch [70/200] - Loss: 0.0043 Epoch [71/200] - Loss: 0.0041 Epoch [72/200] - Loss: 0.0041 Epoch [73/200] - Loss: 0.0041 Epoch [74/200] - Loss: 0.0040 Epoch [75/200] - Loss: 0.0039 Epoch [76/200] - Loss: 0.0040 Epoch [77/200] - Loss: 0.0040 Epoch [78/200] - Loss: 0.0043 Epoch [79/200] - Loss: 0.0038 Epoch [80/200] - Loss: 0.0038 Epoch [81/200] - Loss: 0.0037 Epoch [82/200] - Loss: 0.0037 Epoch [83/200] - Loss: 0.0037 Epoch [84/200] - Loss: 0.0036 Epoch [85/200] - Loss: 0.0036 Epoch [86/200] - Loss: 0.0035 Epoch [87/200] - Loss: 0.0034 Epoch [88/200] - Loss: 0.0033 Epoch [89/200] - Loss: 0.0034 Epoch [90/200] - Loss: 0.0032 Epoch [91/200] - Loss: 0.0032 Epoch [92/200] - Loss: 0.0032 Epoch [93/200] - Loss: 0.0031 Epoch [94/200] - Loss: 0.0030 Epoch [95/200] - Loss: 0.0030 Epoch [96/200] - Loss: 0.0029 Epoch [97/200] - Loss: 0.0030 Epoch [98/200] - Loss: 0.0029 Epoch [99/200] - Loss: 0.0029 Epoch [100/200] - Loss: 0.0029 Epoch [101/200] - Loss: 0.0060 Epoch [102/200] - Loss: 0.0028 Epoch [103/200] - Loss: 0.0028 Epoch [104/200] - Loss: 0.0027 Epoch [105/200] - Loss: 0.0028 Epoch [106/200] - Loss: 0.0027 Epoch [107/200] - Loss: 0.0028 Epoch [108/200] - Loss: 0.0026 Epoch [109/200] - Loss: 0.0027 Epoch [110/200] - Loss: 0.0026 Epoch [111/200] - Loss: 0.0025 Epoch [112/200] - Loss: 0.0026 Epoch [113/200] - Loss: 0.0026 Epoch [114/200] - Loss: 0.0025 Epoch [115/200] - Loss: 0.0025 Epoch [116/200] - Loss: 0.0025 Epoch [117/200] - Loss: 0.0026 Epoch [118/200] - Loss: 0.0025 Epoch [119/200] - Loss: 0.0025 Epoch [120/200] - Loss: 0.0025 Epoch [121/200] - Loss: 0.0025 Epoch [122/200] - Loss: 0.0025 Epoch [123/200] - Loss: 0.0024 Epoch [124/200] - Loss: 0.0025 Epoch [125/200] - Loss: 0.0024 Epoch [126/200] - Loss: 0.0024 Epoch [127/200] - Loss: 0.0023 Epoch [128/200] - Loss: 0.0023 Epoch [129/200] - Loss: 0.0024 Epoch [130/200] - Loss: 0.0023 Epoch [131/200] - Loss: 0.0024 Epoch [132/200] - Loss: 0.0022 Epoch [133/200] - Loss: 0.0022 Epoch [134/200] - Loss: 0.0023 Epoch [135/200] - Loss: 0.0023 Epoch [136/200] - Loss: 0.0022 Epoch [137/200] - Loss: 0.0023 Epoch [138/200] - Loss: 0.0022 Epoch [139/200] - Loss: 0.0022 Epoch [140/200] - Loss: 0.0022 Epoch [141/200] - Loss: 0.0022 Epoch [142/200] - Loss: 0.0021 Epoch [143/200] - Loss: 0.0021 Epoch [144/200] - Loss: 0.0021 Epoch [145/200] - Loss: 0.0021 Epoch [146/200] - Loss: 0.0022 Epoch [147/200] - Loss: 0.0021 Epoch [148/200] - Loss: 0.0020 Epoch [149/200] - Loss: 0.0021 Epoch [150/200] - Loss: 0.0021 Epoch [151/200] - Loss: 0.0021 Epoch [152/200] - Loss: 0.0021 Epoch [153/200] - Loss: 0.0021 Epoch [154/200] - Loss: 0.0023 Epoch [155/200] - Loss: 0.0023 Epoch [156/200] - Loss: 0.0020 Epoch [157/200] - Loss: 0.0020 Epoch [158/200] - Loss: 0.0020 Epoch [159/200] - Loss: 0.0020 Epoch [160/200] - Loss: 0.0019 Epoch [161/200] - Loss: 0.0021 Epoch [162/200] - Loss: 0.0019 Epoch [163/200] - Loss: 0.0019 Epoch [164/200] - Loss: 0.0019 Epoch [165/200] - Loss: 0.0020 Epoch [166/200] - Loss: 0.0023 Epoch [167/200] - Loss: 0.0019 Epoch [168/200] - Loss: 0.0019 Epoch [169/200] - Loss: 0.0021 Epoch [170/200] - Loss: 0.0018 Epoch [171/200] - Loss: 0.0019 Epoch [172/200] - Loss: 0.0018 Epoch [173/200] - Loss: 0.0019 Epoch [174/200] - Loss: 0.0019 Epoch [175/200] - Loss: 0.0019 Epoch [176/200] - Loss: 0.0018 Epoch [177/200] - Loss: 0.0018 Epoch [178/200] - Loss: 0.0018 Epoch [179/200] - Loss: 0.0019 Epoch [180/200] - Loss: 0.0019 Epoch [181/200] - Loss: 0.0018 Epoch [182/200] - Loss: 0.0018 Epoch [183/200] - Loss: 0.0018 Epoch [184/200] - Loss: 0.0018 Epoch [185/200] - Loss: 0.0019 Epoch [186/200] - Loss: 0.0019 Epoch [187/200] - Loss: 0.0024 Epoch [188/200] - Loss: 0.0018 Epoch [189/200] - Loss: 0.0017 Epoch [190/200] - Loss: 0.0018 Epoch [191/200] - Loss: 0.0017 Epoch [192/200] - Loss: 0.0018 Epoch [193/200] - Loss: 0.0018 Epoch [194/200] - Loss: 0.0018 Epoch [195/200] - Loss: 0.0017 Epoch [196/200] - Loss: 0.0018 Epoch [197/200] - Loss: 0.0018 Epoch [198/200] - Loss: 0.0018 Epoch [199/200] - Loss: 0.0017 Epoch [200/200] - Loss: 0.0018
DDPM Sampling¶
In [106]:
@torch.no_grad()
def sample_ddpm(model, num_samples=1, signal_length=10*fs):
"""
DDPM reverse diffusion sampling
"""
model.eval()
model.to(device)
# 1) x_T ~ N(0, I)
x_t = torch.randn(num_samples, 1, signal_length).to(device)
for i in reversed(range(T)): # i: T-1 down to 0
# i는 파이썬 인덱스, 실제 t는 i+1
t_val = torch.tensor([i+1]*num_samples, device=device) # shape = [num_samples]
if (i+1)%100 == 0 :
plt.figure(figsize=(120,5))
for j in range(num_samples):
signal = x_t[j,0].to('cpu').numpy()
plt.subplot(1, 10, j+1)
plt.title(f'X_{(i+1)} Signal')
plt.plot(signal)
plt.ylim(-4, 4)
plt.show()
# 모델의 예측 노이즈
eps = model(x_t, t_val)
sigma_t = torch.sqrt(beta[i])
z = torch.randn(num_samples, 1, signal_length).to(device)
if i == 0: sigma_t = 0
alpha_t = alpha[i]
alpha_bar_t = alpha_bar[i]
# (주의) alpha_bar[i]는 t=i+1에 해당
# 2) 역방향 공식
# x_{t-1} = 1/sqrt(alpha_t) ( x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)* eps )
# + sigma_t * z (if we add noise)
# 여기서는 단순하게 sigma=0 가정
one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
one_minus_alpha_t = 1.0 - alpha_t
sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
# reshape to match (batch_size, 1, 1)
one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.view(1,1,1).to(device)
one_minus_alpha_t = one_minus_alpha_t.view(1,1,1).to(device)
sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.view(1,1,1).to(device)
x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps) + sigma_t * z
x_t = x_prev # update
# 최종 x_0 반환
return x_t
In [107]:
# 샘플링 예시
num_samples = 4
samples = sample_ddpm(model, num_samples=num_samples, signal_length=1*fs)
t = torch.linspace(0, 1, 1*fs)
print("samples shape:", samples.shape)
for i in range(num_samples):
sample = samples[i,0].to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, sample)
plt.ylim(-4, 4)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
samples shape: torch.Size([4, 1, 4000])
$$ X_t = \sqrt{\bar{\alpha}_t}X_0 + \sqrt{1-\bar{\alpha}_t}\varepsilon \;\;\; where \;\; \varepsilon \sim N(0,1) $$
In [108]:
@torch.no_grad()
def partial_diffusion(model, X0, Lambda=200, num_samples=1, signal_length=10*fs):
"""
DDPM reverse partial diffusion sampling
Lambda: partial diffusion time
"""
model.eval()
model.to(device)
# 1) x_t = sqrt(alpha_bar_t)*X_0 + sqrt(1 - alpha_bar_t)*eps
sqrt_alpha_bar_Lambda = torch.sqrt(alpha_bar[Lambda-1]).view(1,1,1).to(device)
sqrt_one_minus_alpha_bar_Lambda = torch.sqrt(1.0 - alpha_bar[Lambda-1]).view(1,1,1).to(device)
eps = torch.randn(num_samples, 1, signal_length).to(device)
x_t = sqrt_alpha_bar_Lambda @ X0 + sqrt_one_minus_alpha_bar_Lambda @ eps
for i in reversed(range(Lambda)): # i: Lambda-1 down to 0
# i는 파이썬 인덱스, 실제 t는 i+1
t_val = torch.tensor([i+1]*num_samples, device=device) # shape = [num_samples]
if (i+1) % (Lambda//10) == 0 :
plt.figure(figsize=(120,5))
for j in range(num_samples):
signal = x_t[j,0].to('cpu').numpy()
plt.subplot(1, 10, j+1)
plt.title(f'X_{(i+1)} Signal')
plt.plot(signal)
plt.ylim(-4, 4)
plt.show()
# 모델의 예측 노이즈
eps = model(x_t, t_val)
sigma_t = torch.sqrt(beta[i])
z = torch.randn(num_samples, 1, signal_length).to(device)
if i == 0: sigma_t = 0
alpha_t = alpha[i]
alpha_bar_t = alpha_bar[i]
# (주의) alpha_bar[i]는 t=i+1에 해당
# 2) 역방향 공식
# x_{t-1} = 1/sqrt(alpha_t) ( x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)* eps )
# + sigma_t * z (if we add noise)
one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
one_minus_alpha_t = 1.0 - alpha_t
sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
# reshape to match (batch_size, 1, 1)
one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.view(1,1,1).to(device)
one_minus_alpha_t = one_minus_alpha_t.view(1,1,1).to(device)
sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.view(1,1,1).to(device)
x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps) + sigma_t * z
x_t = x_prev # update
# 최종 x_0 반환
return x_t
In [109]:
# partial diffusion 예시
num_samples = 1
Lambda = 200
fs = 4000
t = torch.linspace(0, 1, 1*fs)
original = torch.sin(2 * torch.pi * t).reshape(1, 1, -1).to(device)
samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape) # [num_samples, 1, 5*fs]
org_signal = original.reshape(-1).to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
for i in range(num_samples):
sample = samples[i,0].to('cpu')
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,10))
plt.subplot(2, 2, 1)
plt.plot(t, sample)
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.subplot(2, 2, 3)
plt.plot(t, sample - org_signal)
plt.grid()
res_fft = torch.fft.fft(sample - org_signal)
res_fft_power = torch.abs(torch.fft.fftshift(res_fft))
plt.subplot(2, 2, 4)
plt.plot(w, res_fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
samples shape: torch.Size([1, 1, 4000])
In [110]:
# partial diffusion 예시
num_samples = 1
Lambda = 200
t = torch.linspace(0, 1, 1*fs)
s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
original = (s1 + s2).to(device)
samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape) # [num_samples, 1, 5*fs]
org_signal = original.reshape(-1).to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
for i in range(num_samples):
sample = samples[i,0].to('cpu')
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,10))
plt.subplot(2, 2, 1)
plt.plot(t, sample)
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.subplot(2, 2, 3)
plt.plot(t, sample - org_signal)
plt.grid()
res_fft = torch.fft.fft(sample - org_signal)
res_fft_power = torch.abs(torch.fft.fftshift(res_fft))
plt.subplot(2, 2, 4)
plt.plot(w, res_fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
samples shape: torch.Size([1, 1, 4000])
In [124]:
from scipy.stats import norm
In [133]:
# partial diffusion 예시
num_samples = 1
Lambda = 200
t = torch.linspace(0, 1, 1*fs)
s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)
original = (s1 + s2 + noise).to(device)
samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape) # [num_samples, 1, 10*fs]
org_signal = original.reshape(-1).to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
for i in range(num_samples):
sample = samples[i,0].to('cpu')
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,10))
plt.subplot(2, 2, 1)
plt.plot(t, sample)
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
residual = sample - org_signal
plt.subplot(2, 2, 3)
plt.plot(t, residual)
plt.grid()
res_fft = torch.fft.fft(residual)
res_fft_power = torch.abs(torch.fft.fftshift(res_fft))
plt.subplot(2, 2, 4)
plt.plot(w, res_fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
plt.figure(figsize=(12, 8))
gaus_x = torch.arange(-4, 4, 0.001)
plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
samples shape: torch.Size([1, 1, 4000])
In [134]:
# partial diffusion 예시
num_samples = 1
Lambda = 400
t = torch.linspace(0, 1, 1*fs)
s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)
original = (s1 + s2 + noise).to(device)
samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape) # [num_samples, 1, 10*fs]
org_signal = original.reshape(-1).to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
for i in range(num_samples):
sample = samples[i,0].to('cpu')
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,10))
plt.subplot(2, 2, 1)
plt.plot(t, sample)
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
residual = sample - org_signal
plt.subplot(2, 2, 3)
plt.plot(t, residual)
plt.grid()
res_fft = torch.fft.fft(residual)
res_fft_power = torch.abs(torch.fft.fftshift(res_fft))
plt.subplot(2, 2, 4)
plt.plot(w, res_fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
plt.figure(figsize=(12, 8))
gaus_x = torch.arange(-4, 4, 0.001)
plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
samples shape: torch.Size([1, 1, 4000])
In [135]:
# partial diffusion 예시
num_samples = 1
Lambda = 600
t = torch.linspace(0, 1, 1*fs)
s1 = torch.sin(10 * torch.pi * t).reshape(1, 1, -1)
s2 = torch.sin(210 * torch.pi * t).reshape(1, 1, -1)
noise = 1 * torch.randn(t.shape)
original = (s1 + s2 + noise).to(device)
samples = partial_diffusion(model, original, Lambda=Lambda, num_samples=num_samples, signal_length=1*fs)
print("samples shape:", samples.shape) # [num_samples, 1, 10*fs]
org_signal = original.reshape(-1).to('cpu')
w = torch.linspace(-fs/2, fs/2, 1*fs)
fft = torch.fft.fft(org_signal)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, org_signal)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
for i in range(num_samples):
sample = samples[i,0].to('cpu')
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,10))
plt.subplot(2, 2, 1)
plt.plot(t, sample)
plt.grid()
plt.subplot(2, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
residual = sample - org_signal
plt.subplot(2, 2, 3)
plt.plot(t, residual)
plt.grid()
res_fft = torch.fft.fft(residual)
res_fft_power = torch.abs(torch.fft.fftshift(res_fft))
plt.subplot(2, 2, 4)
plt.plot(w, res_fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
plt.figure(figsize=(12, 8))
gaus_x = torch.arange(-4, 4, 0.001)
plt.title(f'mean: {torch.mean(residual)}, std: {torch.std(residual)}')
plt.hist(residual, bins=200)
plt.plot(gaus_x, 130*norm.pdf(gaus_x, 0, 1))
plt.show()
samples shape: torch.Size([1, 1, 4000])
In [117]:
# 샘플링 예시
num_samples = 8
samples = sample_ddpm(model, num_samples=num_samples, signal_length=2*fs)
t = torch.linspace(0, 2, 2*fs)
print("samples shape:", samples.shape)
for i in range(num_samples):
sample = samples[i,0].to('cpu')
w = torch.linspace(-fs/2, fs/2, 2*fs)
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, sample)
plt.ylim(-4, 4)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 120)
plt.grid()
plt.show()
samples shape: torch.Size([8, 1, 8000])
DDIM Sampling¶
In [114]:
@torch.no_grad()
def sample_ddim(model, num_samples=1, signal_length=10*fs):
"""
DDPM reverse diffusion sampling
"""
model.eval()
model.to(device)
# 1) x_T ~ N(0, I)
x_t = torch.randn(num_samples, 1, signal_length).to(device)
for i in reversed(range(T)): # i: T-1 down to 0
# i는 파이썬 인덱스, 실제 t는 i+1
t_val = torch.tensor([i+1]*num_samples, device=device) # shape = [num_samples]
if (i+1)%100 == 0 :
plt.figure(figsize=(120,5))
for j in range(num_samples):
signal = x_t[j,0].to('cpu').numpy()
plt.subplot(1, 10, j+1)
plt.title(f'X_{(i+1)} Signal')
plt.plot(signal)
plt.ylim(-4, 4)
plt.show()
# 모델의 예측 노이즈
eps = model(x_t, t_val)
z = torch.randn(num_samples, 1, signal_length).to(device)
# (주의) alpha_bar[i]는 t=i+1에 해당
alpha_t = alpha[i]
alpha_bar_t = alpha_bar[i]
sigma_t = torch.sqrt((1.0 - alpha_t) * (1.0 - alpha_bar_t/alpha_t) / (1.0 - alpha_bar_t))
# 역방향 공식 (DDIM)
# x_{t-1} = 1/sqrt(alpha_t) (x_t - sqrt(1-alpha_bar_t)*eps) + sqrt(1 - alpha_bar_t/alpha_t - sigma_t^2)*eps + sigma_t * z
one_over_sqrt_alpha_t = 1.0 / torch.sqrt(alpha_t)
sqrt_one_minus_alpha_bar_t = torch.sqrt(1.0 - alpha_bar_t)
sqrt_one_minus_alpha_t_minus_1_minus_sigma_t_square = torch.sqrt(1.0 - alpha_bar_t/alpha_t - sigma_t**2)
if i != 0:
x_prev = one_over_sqrt_alpha_t * (x_t - sqrt_one_minus_alpha_bar_t*eps) + sqrt_one_minus_alpha_t_minus_1_minus_sigma_t_square*eps + sigma_t*z
# x_1 -> x_0
if i == 0:
# 역방향 공식 (DDPM)
# x_{t-1} = 1/sqrt(alpha_t) (x_t - (1-alpha_t)/sqrt(1-alpha_bar_t)*eps)
one_minus_alpha_t = 1.0 - alpha_t
# reshape to match (batch_size, 1, 1)
one_over_sqrt_alpha_t = one_over_sqrt_alpha_t.to(device)
one_minus_alpha_t = one_minus_alpha_t.to(device)
sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t.to(device)
x_prev = one_over_sqrt_alpha_t * (x_t - (one_minus_alpha_t / sqrt_one_minus_alpha_bar_t) * eps)
x_t = x_prev # update
# 최종 x_0 반환
return x_t
In [115]:
# 샘플링 예시
num_samples = 5
samples = sample_ddim(model, num_samples=num_samples, signal_length=10*fs)
t = torch.linspace(0, 10, 10*fs)
print("samples shape:", samples.shape)
for i in range(num_samples):
sample = samples[i,0].to('cpu')
w = torch.linspace(-fs/2, fs/2, 10*fs)
fft = torch.fft.fft(sample)
fft_power = torch.abs(torch.fft.fftshift(fft))
plt.figure(figsize=(25,5))
plt.subplot(1, 2, 1)
plt.plot(t, sample)
plt.ylim(-4, 4)
plt.grid()
plt.subplot(1, 2, 2)
plt.plot(w, fft_power)
plt.xlim(0, 20)
plt.grid()
plt.show()
samples shape: torch.Size([5, 1, 40000])
In [ ]: